## master preamble ------------------------------------------------------
## start from an empty workspace so results never depend on leftover objects
rm(list = ls(all.names = TRUE))

## inherit directories from stata -----------
## Command-line arguments (trailing only):
##   args[1] -> dir.proj, project root (a trailing '/' is appended)
##   args[2] -> dir.lib, R library path; '' means use the default search path
## Missing arguments fall back to the current directory and the default
## library, so the script also runs when started without arguments
## (previously dir.proj / dir.lib were left undefined and the script crashed).
args = commandArgs(trailingOnly = TRUE)
dir.proj = './'
dir.lib  = ''
if (length(args) >= 1 && nzchar(args[1])) {
  dir.proj = paste0(args[1], '/')
}
if (length(args) >= 2) {
  dir.lib = args[2]
}

## if loading from library -------------------
if (dir.lib != '') {
  ## load dependencies --------------------
  ## NOTE(review): presumably loaded explicitly from dir.lib so that the
  ## packages tidyverse attaches resolve from that library -- confirm
  library(R.utils,lib.loc=dir.lib)
  library(ggplot2,lib.loc=dir.lib)
  library(tidyr,lib.loc=dir.lib)
  library(stringr,lib.loc=dir.lib)
  library(forcats,lib.loc=dir.lib)
  library(lubridate,lib.loc=dir.lib)
  library(haven,lib.loc=dir.lib)
  library(data.table,lib.loc=dir.lib)
  ## load main packages ------------------
  library(rio,lib.loc=dir.lib)
  library(dplyr,lib.loc=dir.lib)
  library(tidyverse,lib.loc=dir.lib)
  library(fixest,lib.loc=dir.lib)
## otherwise: simple load -----------------
} else {
  library(rio)
  library(dplyr)
  library(tidyverse)
  library(fixest)
}
## --------------------------------------------------------------------

## file preamble ------------------------------------------------------
## directories, all built by prefixing the project root dir.proj --------
dir.data <- paste0(dir.proj, 'data/out/')
dir.est  <- paste0(dir.proj, 'estimates/')
dir.temp <- paste0(dir.proj, '_temp/')

## estimation settings --------
## Q        : highest power of the instrument used in the polynomial fit
## bootrep  : number of bootstrap replicates after the point-estimate pass
## name.out : file stem of the csv written to dir.est
## y.list   : outcome columns speed_5, speed_10, ..., speed_30
Q        <- 2
bootrep  <- 250
name.out <- 'speeds_q2'
y.list   <- sprintf('speed_%d', seq(5L, 30L, by = 5L))

## randomization seed (fixes the bootstrap weight draws) --------
set.seed(413)

## functions ----------------
## rudirichlet(n, d): n independent draws from the flat Dirichlet(1, ..., 1)
## distribution over d categories, built by normalising iid Exp(1) variables.
## Returns an n x d matrix in which every ROW is one draw and sums to 1.
rudirichlet <- function(n, d) {
  exp.draws <- matrix(rexp(d * n, 1), nrow = n, ncol = d)
  sweep(exp.draws, MARGIN = 1, STATS = rowSums(exp.draws), FUN = '/')
}
## --------------------------------------------------------------------


## setup data -------------------
## Read the prepared extract and keep only the estimation variables, renaming
## harsh -> D, covbin1 -> W and officerid -> jid on the way.
data = import(paste0(dir.temp,'speeds_data.dta')) %>%
  select(citationid, D = harsh, Z, W = covbin1, totfe, jid = officerid,
         year, all_of(y.list))

## generic polynomials in instrument: columns Z1 = Z^1, ..., Z8 = Z^8 -------
for (pow in 1:8) {
  data[[paste0('Z', pow)]] = data$Z^pow
}

## setup for clustered bootstrap -------------------
## one row per officer (jid); officers are the resampled clusters
list    = unique(select(data,jid))
## Bayesian bootstrap weights: an nrow(list) x bootrep matrix whose column j is
## one Dirichlet(1, ..., 1) draw over ALL officers (rows = officers,
## columns = replicates). rudirichlet() returns draws in rows, hence the
## transpose. Normalising across officers, not across replicates, keeps the
## replicates independent. Scaling by the number of officers gives mean
## weight 1; weighted means and WLS fits are invariant to this common scale.
weights = t(rudirichlet(bootrep, nrow(list))) * nrow(list)


## Looping -----------
## Bayesian cluster bootstrap. Replicate j = 0 uses unit weights (point
## estimates); replicates j = 1..bootrep weight officers by column j of
## `weights`. For every outcome in y.list it computes, all weighted by bootwt:
##   y, p       : mean outcome and mean treatment D
##   y0, y1     : fitted mean outcome at Z = 0 and Z = 1 from a degree-Q
##                polynomial in Z (columns Z1..ZQ) with totfe fixed effects
##   y0d0, y1d1 : mean outcome among D = 0 and D = 1 (from Y ~ D)
## plus the derived parameters listed in `est`. Results are stacked in
## boot.out with columns par, est, iter, yvar.

## polynomial formula Y ~ Z1 + ... + ZQ | totfe (loop-invariant)
fml.poly = as.formula(
  paste0('Y~', paste(paste0('Z', seq_len(Q)), collapse = '+'), '|totfe'))

## preallocated result list, bound once after the loop
boot.res = vector('list', (bootrep + 1) * length(y.list))

for (j in 0:bootrep) {

  ## setup bootstrap data: attach this replicate's officer weights --------
  if (j == 0) {
    boot.list = list %>% mutate(bootwt = 1)
  } else {
    boot.list = list %>% mutate(bootwt = weights[, j])
  }
  boot.data = inner_join(data, boot.list, by = 'jid')

  ## start outcomes loop ---------------
  for (k in seq_along(y.list)) {

    ## set outcome on a per-outcome copy so boot.data stays unchanged;
    ## for count outcomes (k >= 4), use only 2015 and before -----
    est.data = mutate(boot.data, Y = boot.data[[y.list[k]]])
    if (k >= 4) {
      est.data = filter(est.data, year <= 2015)
    }
    wt = est.data$bootwt

    ## get sample means ------------
    y = weighted.mean(est.data$Y, wt)
    p = weighted.mean(est.data$D, wt)

    ## polynomial fit --------------
    fit.poly = fixest::feols(fml.poly, est.data, weights = ~bootwt)

    ## compute extrapolation ests ------------
    ## predict(..., fixef = TRUE) gives each observation's totfe fixed effect.
    ## At Z = 0 the polynomial vanishes, so y0 is the mean fixed effect. It
    ## must use the bootstrap weights like every other statistic; an
    ## unweighted mean ignores the replicate's reweighting. At Z = 1 every
    ## power of Z is 1, so y1 adds the sum of the polynomial coefficients.
    fe = stats::predict(fit.poly, newdata = est.data, fixef = TRUE)
    y0 = weighted.mean(fe$totfe, wt)
    y1 = y0 + sum(fit.poly$coefficients)

    ## compute other pars -------------
    fit.d = fixest::feols(Y ~ D, est.data, weights = ~bootwt)
    y0d0 = as.numeric(fit.d$coefficients[1])
    y1d1 = as.numeric(fit.d$coefficients[1] + fit.d$coefficients[2])

    ## construct data with estimates --------
    est = data.frame(
      par=c('Y','p','Y0','Y1','ATE','ATT','ATU',
            'Y1D1','Y1D0','Y0D0','Y0D1','Y0D1mY0D0','ATTmATU'),
      est=c(y,p,y0,y1,(y1-y0),(y-y0)/p,(y1-y)/(1-p),
            y1d1,(1/(1-p))*y1-(p/(1-p))*y1d1,
            y0d0,(1/p)*y0-((1-p)/p)*y0d0,
            ((1/p)*y0-((1-p)/p)*y0d0)-y0d0,
            ((y-y0)/p)-(y1-y)/(1-p))) %>% mutate(iter=j,yvar=y.list[k])

    ## store in slot (j, k) of the preallocated list --------------
    boot.res[[j * length(y.list) + k]] = est
  }

  ## progress report -----
  print(paste('Iteration',j,'of',bootrep,'Completed'))

}
boot.out = bind_rows(boot.res)
## end of bootstrap -------------------------


  
## Bootstrap output -----------
## est.main : point estimates from the unit-weight replicate (iter 0)
## est.boot : per (yvar, par) bootstrap SD and 5% / 95% quantiles over
##            replicates 1..bootrep
## est.out  : the two joined on (yvar, par), plus est +/- 1.96 * se bounds,
##            written to <dir.est><name.out>.csv
## NOTE(review): lq/uq form a 90% percentile interval while lb/ub form a 95%
## normal interval -- confirm the differing coverage is intended.
est.main = boot.out %>% filter(iter == 0) %>% select(yvar, par, est)
est.boot = boot.out %>%
  filter(iter >= 1) %>%
  group_by(yvar, par) %>%
  summarize(se = sd(est), lq = quantile(est, 0.05), uq = quantile(est, 0.95),
            .groups = 'drop')
est.out  = inner_join(est.main, est.boot, by = c('yvar', 'par')) %>%
  mutate(lb = est - 1.96 * se, ub = est + 1.96 * se)
export(est.out, paste0(dir.est, name.out, '.csv'))
    
    
    
    